Skip to content

perf: replace Gauss-Jordan with Cholesky precision sampler and pre-compute loop invariants#10

Merged
YuminosukeSato merged 8 commits into
mainfrom
feat/rust-speedup
Mar 23, 2026
Merged

perf: replace Gauss-Jordan with Cholesky precision sampler and pre-compute loop invariants#10
YuminosukeSato merged 8 commits into
mainfrom
feat/rust-speedup

Conversation

@YuminosukeSato
Copy link
Copy Markdown
Owner

Summary

Replace the beta sampling algorithm from Gauss-Jordan matrix inversion + separate multivariate normal sampling (2× O(k³)) with a single Cholesky factorization of the precision matrix (O(k³/6)), matching the approach used by R's bsts package. Additionally, pre-compute loop-invariant values (X^TX matrices and spike-and-slab statistics) before the Gibbs loop.

Motivation

  • Redundant computation: cross_product_matrix(X, T) was called every Gibbs iteration despite X being constant. For k=20, T=400, niter=1000, this wastes 160M floating-point operations.
  • Double cubic cost: invert_matrix() (Gauss-Jordan, O(k³)) followed by sample_mvnormal() (internal Cholesky, O(k³)) = 2× O(k³) per iteration. The new sample_from_precision() does a single Cholesky of the precision matrix.
  • R numerical compatibility: R bsts samples beta via Cholesky of the precision matrix. This change aligns our sampling algorithm with R's approach.
  • Spike-and-slab waste: x_mean and n_j = Σ(x - x̄)² are constants recomputed O(kT) per iteration.

Changes

src/distributions.rs

  • Rename cholesky()cholesky_lower() with doc comments
  • Add forward_solve(), backward_solve_lt(), chol_solve_lower() for triangular system solving
  • Add sample_from_precision() — samples β ~ N(A⁻¹b, σ²A⁻¹) via Cholesky of precision A
  • Delete sample_mvnormal() (no remaining callers)
  • Add 8 unit tests covering k=1, 2×2, near-singular, k=20

src/sampler.rs

  • Pre-compute xtx_static, xtx_seasonal before the Gibbs loop (O(k²T) × 1 instead of × niter)
  • Pre-compute slab_stats (x_mean, n_j per covariate) before the Gibbs loop
  • Replace invert_matrix + sample_mvnormal with sample_from_precision in sample_beta_with_normal_prior
  • Add xtx_precomputed: Option<&[Vec<f64>]> parameter with fallback for backward compatibility
  • Add precomputed_stats: &[(f64, f64)] parameter to sample_spike_and_slab
  • Remove dead loop in n_j guard (residual -= x * 0.0)
  • Delete invert_matrix(), scale_matrix(), test_invert_identity (no remaining callers)

Cargo.toml

  • Add [profile.release] with lto = "thin" and codegen-units = 1

tests/test_rust_speedup.py (new)

  • 12 integration tests covering correctness, determinism, spike-and-slab, k=20 numerical stability, speed benchmarks, and seasonal regression

Benchmark Results

Config Before After Speedup
k=5, T=200, niter=500 0.010s 0.006s ~1.7×
k=10, T=300, niter=500 0.010s 0.011s ~1×
k=20, T=400, niter=500 0.030s 0.022s ~1.4×

Absolute times are small because the existing Rust code already uses SIMD (target-cpu=native). The algorithmic improvement (O(k²T) → O(1) XtX, 2×O(k³) → O(k³/6) Cholesky) shows measurable improvement at larger k. The primary benefit is R numerical compatibility.

Test Plan

  • cargo test — 36 Rust unit tests pass (including 8 new Cholesky tests)
  • .venv/bin/pytest tests/ -v — 224 Python tests pass (including 12 new speedup tests)
  • cargo clippy — 0 warnings
  • Determinism verified: same seed produces identical output before and after
  • k=20 benchmark: all outputs finite, no NaN/Inf

…tation

Tests verify sampler output correctness, determinism, spike-and-slab behavior,
many-covariate numerical stability (k=20), speed benchmarks, and seasonal
regression - all passing on current code as baseline before refactoring.
…ling

Replace private cholesky() with public cholesky_lower(), add forward_solve,
backward_solve_lt, chol_solve_lower, and sample_from_precision. Remove
sample_mvnormal (no remaining callers after sampler.rs migration).

sample_from_precision samples beta ~ N(A^{-1}b, sigma2 * A^{-1}) using
a single Cholesky factorization of the precision matrix A, matching the
approach used by R's bsts package. This replaces the previous approach
of explicit Gauss-Jordan inversion + separate mvnormal sampling.
- Pre-compute cross_product_matrix(X, T) once before the loop instead of
  every iteration (eliminates O(k^2 T) per iteration)
- Pre-compute spike-and-slab x_mean and n_j per covariate (O(1) lookup
  instead of O(T) per covariate per iteration)
- Replace invert_matrix + sample_mvnormal with sample_from_precision
  (single Cholesky instead of Gauss-Jordan + second Cholesky)
- Delete invert_matrix, scale_matrix (no remaining callers)
Add [profile.release] with thin LTO, single codegen unit, and panic=abort
to maximize cross-function inlining and reduce binary size.
Remove unused pytest import and fix line length violations.
These are implementation details of sample_from_precision, not part of
the public API. No external callers exist outside distributions.rs.
Reduces public API surface per refactor-cleaner review.
- Remove panic="abort" from [profile.release] to preserve PyO3's panic
  catch mechanism (prevents Python process crash on Rust panic)
- Eliminate xtx pre.clone() by building posterior_precision directly from
  xtx_ref + prior_precision (avoids k*k Vec clone per iteration)
- Change xtx_precomputed type from Option<&Vec<Vec<f64>>> to
  Option<&[Vec<f64>]> (idiomatic Rust, use .as_deref() at call sites)
- Skip xtx_static computation when spike-and-slab is active (coordinate-
  wise sampling does not use XtX)
- Remove extra blank lines left from scale_matrix deletion
- Remove no-op loop in spike-and-slab n_j < 1e-12 guard: beta[j] is 0.0
  so x_col[t] * 0.0 = 0.0 never changes the residual
- Add #[allow(clippy::too_many_arguments)] to sample_state_path to
  eliminate the last remaining clippy warning
@YuminosukeSato YuminosukeSato merged commit d195077 into main Mar 23, 2026
13 checks passed
YuminosukeSato added a commit that referenced this pull request Mar 23, 2026
perf: replace Gauss-Jordan with Cholesky precision sampler and pre-compute loop invariants
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant